// Copyright 2018-2019 The MathWorks, Inc.

#ifdef BUILDING_LIBMWCOLLISIONCODEGEN
    #include "collisioncodegen/collisioncodegen_ccdExtensions.hpp"
    #include <ccd/ccd_simplex.h>
#else // during portable codegen, all files are placed in a flat directory
    #include "collisioncodegen_ccdExtensions.hpp"
    #include <ccd_simplex.h>
#endif

#include <vector>
#include <limits>

using namespace shared_robotics;

void barycentricCoordinates(const ccd_vec3_t * a, const ccd_vec3_t * b, const ccd_vec3_t * c, const ccd_vec3_t * p, ccd_vec3_t * lambda);

/// Compute the distance from the origin to the given simplex.
///
/// @param simplex       Simplex whose vertices ps[i].v are tested. When it does not hold exactly
///                      1, 2 or 3 points it is treated as a tetrahedron: only its four triangular
///                      faces are examined, and the simplex is reduced IN PLACE to the face closest
///                      to the origin (simplex->last becomes 2).
/// @param closestPoint  Receives the point of the simplex closest to the origin.
/// @return The Euclidean distance from the origin to *closestPoint.
ccd_real_t shared_robotics::distanceToOrigin(ccd_simplex_t * simplex, ccd_vec3_t *& closestPoint)
{
    const auto vertexCount = ccdSimplexSize(simplex);

    if (vertexCount == 1)
    {
        // A single vertex is its own closest point.
        *closestPoint = simplex->ps[0].v;
        return CCD_SQRT(ccdVec3Len2(closestPoint));
    }

    if (vertexCount == 2)
    {
        ccd_vec3_t segStart = simplex->ps[0].v;
        ccd_vec3_t segEnd = simplex->ps[1].v;
        return CCD_SQRT(ccdVec3PointSegmentDist2(ccd_vec3_origin, &segStart, &segEnd, closestPoint));
    }

    if (vertexCount == 3)
    {
        ccd_vec3_t cornerA = simplex->ps[0].v;
        ccd_vec3_t cornerB = simplex->ps[1].v;
        ccd_vec3_t cornerC = simplex->ps[2].v;
        return CCD_SQRT(ccdVec3PointTriDist2(ccd_vec3_origin, &cornerA, &cornerB, &cornerC, closestPoint));
    }

    // Tetrahedron. The simplex produced by GJK is not necessarily refined/reduced enough to report
    // the actual distance, so every face is checked and the nearest one is kept.
    // Row k lists the vertices of the face opposite vertex k.
    const std::size_t faces[4][3] = { { 3, 1, 2 }, { 0, 3, 2 }, { 0, 1, 3 }, { 0, 1, 2 } };

    ccd_vec3_t faceClosest;
    auto faceDistSq = [&](std::size_t face) -> ccd_real_t
    {
        ccd_vec3_t cornerA = simplex->ps[faces[face][0]].v;
        ccd_vec3_t cornerB = simplex->ps[faces[face][1]].v;
        ccd_vec3_t cornerC = simplex->ps[faces[face][2]].v;
        return ccdVec3PointTriDist2(ccd_vec3_origin, &cornerA, &cornerB, &cornerC, &faceClosest);
    };

    // The first face seeds the search; later faces replace it only when strictly closer.
    std::size_t nearestFace = 0;
    ccd_real_t nearestDistSq = faceDistSq(0);
    *closestPoint = faceClosest;
    for (std::size_t face = 1; face < 4; ++face)
    {
        const ccd_real_t candidateDistSq = faceDistSq(face);
        if (candidateDistSq < nearestDistSq)
        {
            nearestDistSq = candidateDistSq;
            nearestFace = face;
            *closestPoint = faceClosest;
        }
    }

    // Shrink to the nearest face: overwrite the vertex opposite that face with vertex 3.
    // Face 3 already occupies slots 0..2, so nothing moves in that case.
    if (nearestFace < 3)
    {
        simplex->ps[nearestFace] = simplex->ps[3];
    }
    simplex->last = 2;

    return CCD_SQRT(nearestDistSq);
}


/// Compute the witness points: a pair of points, one per object, that realize the minimum
/// separation distance represented by closestPoint on the given simplex.
///
/// @param simplex       Simplex holding the support points (with their v1/v2 components).
/// @param closestPoint  Point on the simplex closest to the origin.
/// @param p1, p2        Receive the interpolated v1 and v2 components respectively.
///                      Left unmodified when the simplex size is not 1, 2 or 3.
void shared_robotics::extractWitnessPoints(const ccd_simplex_t *simplex, const ccd_vec3_t *closestPoint, ccd_vec3_t *p1, ccd_vec3_t *p2)
{
    const auto vertexCount = ccdSimplexSize(simplex);

    if (vertexCount == 1)
    {
        // Single support point: its components are the witnesses directly.
        *p1 = simplex->ps[0].v1;
        *p2 = simplex->ps[0].v2;
    }
    else if (vertexCount == 2)
    {
        interpolateBetweenTwoSupportPoints(simplex, closestPoint, p1, p2);
    }
    else if (vertexCount == 3)
    {
        interpolateAmongThreeSupportPoints(simplex, closestPoint, p1, p2);
    }
}

void shared_robotics::interpolateBetweenTwoSupportPoints(const ccd_simplex_t *& simplex, const ccd_vec3_t *& closestPoint, ccd_vec3_t *& p1, ccd_vec3_t *& p2)
{
    // interpolate between two support points m, n
    //
    // o------------o-------------o
    // m(m1,m2)     p(p1,p2)      n(n1,n2)
    //
    ccd_vec3_t m, n, p, dp, dn;
    ccdVec3Copy(&m, &simplex->ps[0].v);
    ccdVec3Copy(&n, &simplex->ps[1].v);
    ccdVec3Copy(&p, closestPoint);

    ccdVec3Sub2(&dp, &p, &m);
    ccdVec3Sub2(&dn, &n, &m);
    auto dpm = CCD_SQRT(ccdVec3Len2(&dp));
    auto dnm = CCD_SQRT(ccdVec3Len2(&dn));

    if (dnm < std::numeric_limits<ccd_real_t>::epsilon())
    {
        *p1 = simplex->ps[0].v1;
        *p2 = simplex->ps[0].v2;
    }
    else
    {
        ccd_vec3_t m1, m2, n1, n2, tmp, tmp2;
        ccdVec3Copy(&m1, &simplex->ps[0].v1);
        ccdVec3Copy(&m2, &simplex->ps[0].v2);
        ccdVec3Copy(&n1, &simplex->ps[1].v1);
        ccdVec3Copy(&n2, &simplex->ps[1].v2);

        auto s = dpm / dnm;
        ccdVec3Sub2(&tmp, &n1, &m1);
        ccdVec3Scale(&tmp, s);
        ccdVec3Add(&tmp, &m1);
        
        ccdVec3Copy(p1, &tmp);

        ccdVec3Sub2(&tmp2, &n2, &m2);
        ccdVec3Scale(&tmp2, s);
        ccdVec3Add(&tmp2, &m2);
        
        ccdVec3Copy(p2, &tmp2);
    }
}

/// convert a point's position in global coordinates to the barycentric coordinates defined by three points in global coordinates (internal use only).
void barycentricCoordinates(const ccd_vec3_t * a, const ccd_vec3_t * b, const ccd_vec3_t * c, const ccd_vec3_t * p, ccd_vec3_t * lambda)
{
    // a, b, c must not be co-linear
    ccd_vec3_t ab, ac, bc, bp, cp, n, na, nb, neg_ac;
    ccdVec3Sub2(&ab, b, a);
    ccdVec3Sub2(&ac, c, a);
    ccdVec3Sub2(&bc, c, b);

    ccdVec3Sub2(&bp, p, b);
    ccdVec3Sub2(&cp, p, c);

    ccdVec3Cross(&n, &ab, &ac);
    ccdVec3Cross(&na, &bc, &bp);
    ccdVec3Copy(&neg_ac, &ac);
    ccdVec3Scale(&neg_ac, -1);
    ccdVec3Cross(&nb, &neg_ac, &cp);


    ccd_real_t nn = ccdVec3Dot(&n, &n);
    ccd_real_t alpha = ccdVec3Dot(&n, &na) / nn;
    ccd_real_t beta = ccdVec3Dot(&n, &nb) / nn;

    ccdVec3Set(lambda, alpha, beta, 1 - alpha - beta);
}

void shared_robotics::interpolateAmongThreeSupportPoints(const ccd_simplex_t *& simplex, const ccd_vec3_t *& closestPoint, ccd_vec3_t *& p1, ccd_vec3_t *& p2)
{
    /* interpolate between three support points a, b, c

               o a(a1,a2)
              / \
             /   \
            /     \
           /       \
          /  o      \
         /   p(p1,p2)\
        o-------------o
        b(b1,b2)       c(c1,c2)
    */

    ccd_vec3_t a, b, c, p, baVec, caVec, cbVec, crprod;
    ccdVec3Copy(&a, &simplex->ps[0].v);
    ccdVec3Copy(&b, &simplex->ps[1].v);
    ccdVec3Copy(&c, &simplex->ps[2].v);
    ccdVec3Copy(&p, closestPoint);
    
    ccdVec3Sub2(&baVec, &b, &a);
    ccdVec3Sub2(&caVec, &c, &a);
    ccdVec3Cross(&crprod, &baVec, &caVec);

    ccdVec3Sub2(&cbVec, &c, &b);

    if ( CCD_SQRT(ccdVec3Len2(&crprod)) < std::numeric_limits<ccd_real_t>::epsilon()) // if triangle area is zero
    {
        ccd_simplex_t simplexTmp;
        ccd_real_t ab = CCD_SQRT(ccdVec3Len2(&baVec));
        ccd_real_t ac = CCD_SQRT(ccdVec3Len2(&caVec));
        ccd_real_t bc = CCD_SQRT(ccdVec3Len2(&cbVec));

        simplexTmp.last = 1;
        if (ab >= ac && ab >= bc)
        {
            simplexTmp.ps[0] = simplex->ps[0];
            simplexTmp.ps[1] = simplex->ps[1];
        }
        else if (ac > ab && ac > bc)
        {
            simplexTmp.ps[0] = simplex->ps[0];
            simplexTmp.ps[1] = simplex->ps[2];
        }
        else
        {
            simplexTmp.ps[0] = simplex->ps[1];
            simplexTmp.ps[1] = simplex->ps[2];
        }

        const ccd_simplex_t * simplexPtr = &simplexTmp;
        interpolateBetweenTwoSupportPoints(simplexPtr, closestPoint, p1, p2);
    }
    else
    {
        ccd_vec3_t lambda = { {0, 0, 0} };
        barycentricCoordinates(&a, &b, &c, &p, &lambda);

        ccd_vec3_t a1, a2, b1, b2, c1, c2;
        ccd_vec3_t tmp1 = { {0, 0, 0} };
        ccd_vec3_t tmp2 = { {0, 0, 0} };
        ccdVec3Copy(&a1, &simplex->ps[0].v1);
        ccdVec3Copy(&a2, &simplex->ps[0].v2);

        ccdVec3Copy(&b1, &simplex->ps[1].v1);
        ccdVec3Copy(&b2, &simplex->ps[1].v2);

        ccdVec3Copy(&c1, &simplex->ps[2].v1);
        ccdVec3Copy(&c2, &simplex->ps[2].v2);

        ccdVec3Scale(&a1, lambda.v[0]);
        ccdVec3Scale(&b1, lambda.v[1]);
        ccdVec3Scale(&c1, lambda.v[2]);

        ccdVec3Add(&tmp1, &a1);
        ccdVec3Add(&tmp1, &b1);
        ccdVec3Add(&tmp1, &c1);

        ccdVec3Copy(p1, &tmp1);

        ccdVec3Scale(&a2, lambda.v[0]);
        ccdVec3Scale(&b2, lambda.v[1]);
        ccdVec3Scale(&c2, lambda.v[2]);

        ccdVec3Add(&tmp2, &a2);
        ccdVec3Add(&tmp2, &b2);
        ccdVec3Add(&tmp2, &c2);

        ccdVec3Copy(p2, &tmp2);
    }

}



/// Compute the separation distance between two objects and the witness points realizing it.
///
/// Runs __ccdGJK first. If it reports an intersection, returns -CCD_ONE without touching p1/p2.
/// Otherwise the GJK simplex is refined: each iteration computes the simplex's distance to the
/// origin, stops when it changes by less than ccd->dist_tolerance, and else adds a new support
/// point in the direction from the current closest point toward the origin.
///
/// @param obj1, obj2  Objects passed through to the support functions configured in ccd.
/// @param ccd         Configuration; max_iterations bounds the refinement loop.
/// @param p1, p2      Receive the witness points (see extractWitnessPoints) when the objects
///                    do not intersect.
/// @return The separation distance, or -CCD_ONE when the objects intersect.
ccd_real_t shared_robotics::ccdDistance(const void *obj1, const void *obj2, const ccd_t *ccd,
                ccd_vec3_t* p1, ccd_vec3_t* p2)
{
    ccd_simplex_t simplex;
    ccd_vec3_t closestPoint;

    // distanceToOrigin / extractWitnessPoints take pointer references, so named pointers are needed.
    ccd_simplex_t * simplexPtr = &simplex;
    ccd_vec3_t *closestPointPtr = &closestPoint;

    // __ccdGJK returns -1 if intersection is not found or 0 if intersection is found.
    if (__ccdGJK(obj1, obj2, ccd, &simplex) == 0)
    {
        return -CCD_ONE;
    }

    const int maxIterations = static_cast<int>(ccd->max_iterations);
    ccd_real_t lastDistance = CCD_REAL_MAX;

    // For a simplex coming out of __ccdGJK this loop usually exits after the first iteration.
    for (int i = 0; i < maxIterations; i++)
    {
        // Also reduces a 4-point simplex to its nearest face, so at most 3 points remain afterwards.
        const ccd_real_t distance = distanceToOrigin(simplexPtr, closestPointPtr);

        if (CCD_FABS(lastDistance - distance) < ccd->dist_tolerance) // no further improvement
        {
            extractWitnessPoints(simplexPtr, closestPointPtr, p1, p2);
            return distance;
        }
        lastDistance = distance;

        // Validate the closest point: query a support point in the direction from the current
        // closest point toward the origin.
        ccd_vec3_t dir = { { -closestPoint.v[0], -closestPoint.v[1], -closestPoint.v[2] } };
        ccd_support_t supportPoint;
        __ccdSupport(obj1, obj2, &dir, ccd, &supportPoint);

        // Add the support point unless it coincides with an existing vertex. A duplicate leaves the
        // simplex unchanged, so the next iteration sees no improvement and terminates.
        bool alreadyInSimplex = false;
        for (std::size_t q = 0; q <= static_cast<std::size_t>(simplex.last); q++)
        {
            if (ccdVec3Eq(&supportPoint.v, &simplex.ps[q].v))
            {
                alreadyInSimplex = true;
                break;
            }
        }

        if (!alreadyInSimplex)
        {
            simplex.ps[simplex.last + 1] = supportPoint;
            simplex.last += 1;
        }
    }

    // Iteration limit reached (or max_iterations <= 0, so the loop never ran). The last iteration
    // may have appended a support point, leaving a 4-point simplex that extractWitnessPoints
    // cannot handle and a closestPoint that no longer matches the simplex (or was never computed).
    // Recompute so that the simplex is reduced and the distance, closest point and witness points
    // are mutually consistent.
    const ccd_real_t distance = distanceToOrigin(simplexPtr, closestPointPtr);
    extractWitnessPoints(simplexPtr, closestPointPtr, p1, p2);
    return distance;
}
